import sys
import pysam
from numpy import *


# Dataset prefix taken from the command line: the script reads
# "<dataset>.sorted.bam" and writes "<dataset>.<N>.ctss.bed" files.
if len(sys.argv) < 2:
    # Fail with a usage message instead of a bare IndexError traceback.
    sys.exit(f"usage: {sys.argv[0]} DATASET")
dataset = sys.argv[1]

# Values of the "XT" (target) tag whose alignments are discarded outright.
skip_targets = tuple(
    "chrM rRNA tRNA snRNA scRNA snoRNA yRNA histone scaRNA snar vRNA".split()
)

# "XT" values whose alignments are counted; any XT value in neither this
# tuple nor skip_targets is treated as unexpected by the main loop.
keep_targets = tuple(
    "RMRP RPPH TERC MALAT1 snhg mRNA lncRNA gencode fantomcat genome".split()
)

# Expected "XE" values on mRNA/lncRNA/gencode alignments; alignments carrying
# an XE tag are dropped.
skip_matures = tuple("spliced_mRNA spliced_lncRNA spliced_gencode".split())

# Expected "XA" values; any alignment carrying an XA tag is dropped.
skip_annotations = tuple("presnRNA pretRNA prescaRNA presnoRNA".split())


def write_bedfile(filename, counts):
    """Write the non-zero per-base counts in *counts* to *filename* as BED.

    Parameters
    ----------
    filename : str
        Output path; it is opened in text mode and overwritten if it exists.
    counts : dict
        Maps a chromosome name to a 2-D array indexed ``[position, strand]``.
        Strand column 0 is written as ``+`` and column 1 as ``-``.

    Each non-zero entry becomes one 6-column line:
    ``chromosome, position, position + 1, ".", count, strand``.
    Within a chromosome, lines follow ``numpy.nonzero`` order, which is
    ascending position and then strand. Chromosomes are written in dict
    iteration order.
    """
    # The context manager closes the file even if a write raises part-way.
    with open(filename, "wt") as output:
        for chromosome, array in counts.items():
            positions, strands = nonzero(array)
            for position, strand in zip(positions, strands):
                count = array[position, strand]
                sign = "+-"[strand]
                output.write(
                    f"{chromosome}\t{position}\t{position + 1}\t.\t{count}\t{sign}\n"
                )


def _passes_filters(alignment):
    """Return True if *alignment* should be counted, False if it is skipped.

    An alignment is skipped when any of the following holds:

    - it is unmapped;
    - its ``XT`` target is in ``skip_targets``;
    - it carries an ``XA`` tag;
    - it is an mRNA/lncRNA/gencode alignment carrying an ``XE`` tag;
    - its ``XL`` tag is shorter than its reference span (spliced);
    - its ``NH`` tag is not 1, so only uniquely mapped reads are counted.

    Raises ValueError when an ``XT``, ``XA`` or ``XE`` value is outside the
    expected vocabulary. These are explicit checks rather than ``assert`` so
    they still run under ``python -O``. A KeyError propagates from
    ``get_tag`` if a mapped alignment lacks the ``XT`` or ``NH`` tag.
    """
    if alignment.is_unmapped:
        return False
    target = alignment.get_tag("XT")
    if target in skip_targets:
        return False
    if target not in keep_targets:
        raise ValueError(
            f"{alignment.query_name}: unexpected XT target {target!r}"
        )
    if alignment.has_tag("XA"):
        annotation = alignment.get_tag("XA")
        if annotation not in skip_annotations:
            raise ValueError(
                f"{alignment.query_name}: unexpected XA annotation {annotation!r}"
            )
        return False
    if target in ("mRNA", "lncRNA", "gencode") and alignment.has_tag("XE"):
        annotation = alignment.get_tag("XE")
        if annotation not in skip_matures:
            raise ValueError(
                f"{alignment.query_name}: unexpected XE annotation {annotation!r}"
            )
        return False
    if alignment.has_tag("XL") and alignment.get_tag("XL") < alignment.reference_length:
        # spliced
        return False
    return alignment.get_tag("NH") == 1


filename = f"{dataset}.sorted.bam"
print("Reading", filename)

# Number of reads between successive intermediate BED snapshots.
snapshot_interval = 1000000

counts = {}
# ``index`` is the 0-based ordinal of the read currently being processed.
# It advances whenever the query name changes.
# NOTE(review): this assumes all alignments of one read are adjacent in the
# file (grouped by name). For a coordinate-sorted BAM, multi-mapped reads
# would be counted more than once. Confirm how "*.sorted.bam" is produced.
index = -1
query_name = None
threshold = snapshot_interval
with pysam.AlignmentFile(filename) as bamfile:
    for chromosome, length in zip(bamfile.references, bamfile.lengths):
        # column 0: forward-strand counts, column 1: reverse-strand counts
        counts[chromosome] = zeros((length, 2))
    for alignment in bamfile:
        if alignment.query_name != query_name:
            index += 1
            query_name = alignment.query_name
        if index == threshold:
            # Reads 0 .. threshold-1 are fully processed, so write a snapshot
            # before counting anything from the current read.
            filename = f"{dataset}.{threshold}.ctss.bed"
            print("Writing %s from %d sequences" % (filename, index))
            write_bedfile(filename, counts)
            threshold += snapshot_interval
        if not _passes_filters(alignment):
            continue
        if alignment.is_reverse:
            # reference_end is exclusive, so the last aligned base (the
            # read's 5' end on the reverse strand) is reference_end - 1.
            strand = 1
            position = alignment.reference_end - 1
        else:
            strand = 0
            position = alignment.reference_start
        # Only NH == 1 reads get here, so each contributes a full count.
        counts[alignment.reference_name][position, strand] += 1.0

# Turn the last read's 0-based ordinal into the number of reads seen
# (0 for an empty file).
index += 1
filename = f"{dataset}.{index}.ctss.bed"
print("Writing %s from %d sequences" % (filename, index))
write_bedfile(filename, counts)
